import argparse
import copy
import os

import toml

from factory import *


def run_seed(cfg_):
    """Build and train one experiment for a single random seed.

    Creates the experiment directory, selects a device, builds the reporter,
    seeds the RNGs via ``set_seed``, then constructs the dataset, model,
    optimizer and experiment through the ``factory`` helpers and trains.

    Args:
        cfg_: Configuration dict for this run. Keys read here:
            ``root_dir``, ``experiment_name``, ``random_seed``,
            ``dataset_name``, ``model_name``, the sub-tables ``loader``,
            ``model``, ``optimizer`` and ``experiment``, and optionally
            ``device_idx``. When ``device_idx`` is absent or ``None``,
            ``set_device`` is called with no arguments.
    """
    experiment_dir = make_dir(cfg_["root_dir"], cfg_["experiment_name"], extra="random_seeds")

    # An absent key and an explicit None are treated the same way. The CLI
    # stores None here when -i is omitted, and that must not be forwarded to
    # set_device as if it were a real index.
    device_idx = cfg_.get("device_idx")
    if device_idx is not None:
        device = set_device(device_idx)
    else:
        device = set_device()

    reporter = make_reporter(experiment_dir=experiment_dir, cfg=cfg_)

    # Seed before any data/model construction so loaders and weight init are
    # affected by the seed.
    set_seed(cfg_["random_seed"])

    data, data_loaders = make_dataset(root_dir=cfg_["root_dir"], dataset_name=cfg_["dataset_name"],
                                      params=cfg_["loader"])

    model = make_model(model_name=cfg_["model_name"], data=data, device=device, params=cfg_["model"])

    optimizer = make_optimizer(model=model, params=cfg_["optimizer"])

    experiment = make_experiment(model=model, optimizer=optimizer, reporter=reporter, device=device,
                                 experiment_dir=experiment_dir, params=cfg_["experiment"])

    # make_dataset's second return value is unpacked as keyword arguments
    # to train(); presumably a dict of named loaders -- confirm in factory.
    experiment.train(**data_loaders)


if __name__ == "__main__":
    # Entry point: load a TOML config, apply CLI overrides, and run one
    # training experiment per configured random seed.
    parser = argparse.ArgumentParser(description="Train one experiment per configured random seed.")
    parser.add_argument('-r', '--root_dir', type=str, default="../",
                        help="project root; the config is read from <root_dir>/configs/")
    parser.add_argument('-d', '--dataset_name', type=str, default="reddit",
                        help="dataset name passed to make_dataset")
    parser.add_argument('-i', '--device_idx', type=int,
                        help="device index passed to set_device (omit to call set_device())")
    parser.add_argument('-c', '--config_file', type=str, default="debug.toml",
                        help="config file name inside <root_dir>/configs/")

    args = parser.parse_args()
    config_path = os.path.join(args.root_dir, "configs", args.config_file)
    # TOML documents are defined to be UTF-8, so don't depend on the locale.
    with open(config_path, mode="r", encoding="utf-8") as f:
        cfg = toml.load(f)

    # CLI values override anything of the same name in the config file.
    cfg["root_dir"] = args.root_dir
    cfg["dataset_name"] = args.dataset_name
    cfg["device_idx"] = args.device_idx

    # `random_seed` lists the seeds to run. Use the configured values
    # themselves, not their positions in the list. A single int seed is
    # accepted as a one-element list.
    seeds = cfg["random_seed"]
    if isinstance(seeds, int):
        seeds = [seeds]

    for seed in seeds:
        # Deep-copy so each run gets an independent config and none can
        # mutate the shared one.
        cfg_ = copy.deepcopy(cfg)
        cfg_["random_seed"] = seed
        run_seed(cfg_)


